import

import torch 
from fastai.vision.all import * 
import cv2 
import fastbook
from fastbook import *
from fastai.vision.widgets import *

data

# Dataset root directory; listing it shows the top-level folders it contains.
path = Path('/home/khy/chest_xray/chest_xray')
path.ls()
(#5) [Path('/home/khy/chest_xray/chest_xray/train'),Path('/home/khy/chest_xray/chest_xray/test'),Path('/home/khy/chest_xray/chest_xray/chest_xray'),Path('/home/khy/chest_xray/chest_xray/__MACOSX'),Path('/home/khy/chest_xray/chest_xray/val')]
# Every image file found under the dataset root (searched recursively).
files = get_image_files(path)
files
(#11712) [Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0766-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1318-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0160-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1327-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0489-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0509-0001-0002.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0761-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0416-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-0566-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0411-0001.jpeg')...]
# Build the dataloaders from train/ only, holding out a random 20% of it (fixed
# seed for reproducibility) as the validation set.
# `from_folder(path, valid_pct=...)` collects images from *every* subfolder of
# `path`: train/, test/, val/, __MACOSX/ and the nested chest_xray/ folder
# (get_image_files found 11712 files, presumably two copies of the dataset).
# That puts test images into training and lets duplicate images land in both
# the training and validation sets, which inflates the validation accuracy.
dls = ImageDataLoaders.from_folder(path / 'train', valid_pct=0.2, seed=42,
                                   item_tfms=Resize(224))
dls.vocab
['NORMAL', 'PNEUMONIA']
# Preview a grid of up to 16 labelled images from the dataloaders.
n_preview = 16
dls.show_batch(max_n=n_preview)
# CAM-friendly classifier. Keep the pretrained ResNet34 body from cnn_learner and
# replace fastai's default head with
#   global average pool -> flatten -> bias-free Linear(512, 2).
# Each class score is then a linear combination of the pooled final feature
# maps, which is what the einsum-based CAM code below relies on.
learn = cnn_learner(dls, resnet34, metrics=error_rate)
net1 = learn.model[0]  # pretrained convolutional body
net2 = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Flatten(),
    torch.nn.Linear(512, out_features=2, bias=False),
)
net = torch.nn.Sequential(net1, net2)
# Split the parameters into two groups (body, head). fine_tune's first, frozen
# epoch then trains only the newly initialised head. With the default single
# group, freeze() leaves every parameter trainable.
lrnr2 = Learner(dls, net, metrics=accuracy,
                splitter=lambda m: [params(m[0]), params(m[1])])
# NOTE(review): in the logged run valid_loss stops improving long before epoch
# 200 while train_loss keeps falling towards 0. Consider fewer epochs or
# SaveModelCallback.
lrnr2.fine_tune(200)
epoch train_loss valid_loss accuracy time
0 0.166842 0.091861 0.967122 00:33
epoch train_loss valid_loss accuracy time
0 0.076691 0.070642 0.973954 00:33
1 0.065596 0.065189 0.976943 00:32
2 0.063810 0.060881 0.977797 00:32
3 0.058133 0.055606 0.979505 00:33
4 0.047295 0.051751 0.982494 00:32
5 0.049507 0.061955 0.975235 00:32
6 0.040383 0.048890 0.982494 00:33
7 0.037072 0.038793 0.985483 00:32
8 0.029895 0.035411 0.988044 00:33
9 0.024122 0.032279 0.988471 00:32
10 0.022319 0.030799 0.990606 00:32
11 0.022883 0.029063 0.990606 00:32
12 0.018799 0.024217 0.993595 00:32
13 0.018655 0.026862 0.991887 00:32
14 0.017203 0.025556 0.991460 00:32
15 0.012168 0.028741 0.991887 00:32
16 0.013291 0.021540 0.991460 00:32
17 0.013113 0.023177 0.993595 00:32
18 0.014589 0.023715 0.993168 00:32
19 0.010889 0.027784 0.992314 00:32
20 0.010598 0.028819 0.992314 00:32
21 0.013652 0.023543 0.993168 00:32
22 0.010165 0.021542 0.993168 00:32
23 0.011329 0.024496 0.992314 00:32
24 0.009473 0.019847 0.992314 00:32
25 0.007470 0.022198 0.990179 00:32
26 0.007615 0.017968 0.995303 00:32
27 0.006131 0.022273 0.995730 00:32
28 0.008292 0.032437 0.992314 00:32
29 0.008912 0.042545 0.988898 00:33
30 0.009870 0.039163 0.988044 00:33
31 0.010967 0.018784 0.992314 00:32
32 0.006510 0.021688 0.991887 00:32
33 0.006636 0.033374 0.991460 00:32
34 0.010336 0.020198 0.993595 00:33
35 0.009317 0.030448 0.991033 00:33
36 0.007046 0.022307 0.993168 00:33
37 0.009590 0.026956 0.990606 00:32
38 0.006269 0.055886 0.985056 00:32
39 0.010013 0.018850 0.994449 00:32
40 0.008058 0.027818 0.993168 00:33
41 0.007327 0.015476 0.993595 00:32
42 0.006886 0.010855 0.997011 00:32
43 0.011692 0.017141 0.997011 00:32
44 0.007462 0.030888 0.990179 00:32
45 0.006464 0.015794 0.992741 00:32
46 0.007760 0.068463 0.984628 00:33
47 0.006637 0.015711 0.993168 00:32
48 0.010105 0.041067 0.988898 00:32
49 0.007672 0.012651 0.996157 00:33
50 0.014199 0.083004 0.974381 00:33
51 0.012289 0.018203 0.993168 00:32
52 0.009026 0.020449 0.994022 00:32
53 0.004553 0.017501 0.993595 00:32
54 0.010326 0.024923 0.991033 00:32
55 0.015319 0.027962 0.992314 00:33
56 0.004357 0.023815 0.994022 00:32
57 0.005287 0.019874 0.992314 00:32
58 0.009573 0.014026 0.995730 00:32
59 0.006735 0.021964 0.993168 00:32
60 0.005811 0.023319 0.990606 00:33
61 0.011406 0.026691 0.992741 00:32
62 0.005277 0.022868 0.994449 00:33
63 0.006119 0.018390 0.994022 00:33
64 0.007875 0.034545 0.994022 00:32
65 0.005800 0.020408 0.994022 00:32
66 0.002680 0.019692 0.994449 00:32
67 0.006419 0.034546 0.991033 00:33
68 0.006348 0.053590 0.986763 00:32
69 0.005590 0.031790 0.993595 00:32
70 0.007865 0.029411 0.994876 00:33
71 0.002760 0.026847 0.993168 00:32
72 0.009839 0.030372 0.992741 00:32
73 0.008680 0.026388 0.992314 00:32
74 0.004330 0.031201 0.992741 00:32
75 0.009632 0.078810 0.984202 00:32
76 0.003771 0.022387 0.992741 00:33
77 0.006113 0.030133 0.992314 00:32
78 0.003496 0.028839 0.995303 00:33
79 0.003018 0.026174 0.994022 00:33
80 0.007461 0.030011 0.993595 00:33
81 0.004392 0.023791 0.994876 00:32
82 0.005972 0.068508 0.987617 00:32
83 0.006191 0.019870 0.996584 00:33
84 0.005330 0.020402 0.996584 00:33
85 0.002982 0.036186 0.993168 00:32
86 0.003956 0.019152 0.994022 00:32
87 0.006709 0.022051 0.994449 00:32
88 0.004887 0.043770 0.991460 00:32
89 0.004027 0.025353 0.993168 00:32
90 0.002959 0.029085 0.992741 00:33
91 0.003077 0.025070 0.993595 00:33
92 0.004699 0.024857 0.992741 00:32
93 0.002660 0.032952 0.995730 00:32
94 0.003100 0.025073 0.994876 00:32
95 0.002563 0.023130 0.994022 00:32
96 0.001407 0.023987 0.995730 00:33
97 0.002879 0.015754 0.996584 00:33
98 0.002273 0.019964 0.995730 00:33
99 0.001539 0.023395 0.994022 00:32
100 0.002776 0.019369 0.997438 00:32
101 0.001925 0.015023 0.996157 00:32
102 0.002006 0.039217 0.991887 00:32
103 0.003615 0.011737 0.997011 00:32
104 0.002477 0.016405 0.995730 00:33
105 0.001914 0.014328 0.997438 00:32
106 0.000848 0.020702 0.995730 00:32
107 0.005377 0.028292 0.994022 00:32
108 0.003150 0.019413 0.996584 00:32
109 0.001558 0.022858 0.995730 00:33
110 0.002981 0.022044 0.995730 00:32
111 0.003152 0.024832 0.993595 00:32
112 0.001988 0.016285 0.995730 00:33
113 0.000533 0.014695 0.995730 00:33
114 0.000902 0.017304 0.995730 00:33
115 0.001843 0.019725 0.995730 00:33
116 0.001038 0.020030 0.995730 00:33
117 0.000729 0.019264 0.994022 00:32
118 0.001277 0.027110 0.994876 00:33
119 0.001734 0.026816 0.993168 00:32
120 0.002050 0.020589 0.995730 00:33
121 0.002221 0.022525 0.995730 00:33
122 0.000572 0.027818 0.993168 00:33
123 0.001051 0.018991 0.994876 00:33
124 0.000295 0.019816 0.994876 00:32
125 0.001252 0.022995 0.995730 00:33
126 0.000770 0.021016 0.994449 00:33
127 0.000683 0.030154 0.994449 00:32
128 0.003303 0.026239 0.995730 00:33
129 0.001704 0.025088 0.994022 00:33
130 0.002516 0.010910 0.996584 00:33
131 0.000699 0.015325 0.996584 00:33
132 0.000870 0.013863 0.996584 00:32
133 0.000663 0.020103 0.995730 00:33
134 0.000980 0.012507 0.996584 00:32
135 0.000181 0.014895 0.995730 00:33
136 0.000645 0.030882 0.994022 00:33
137 0.000258 0.029726 0.994022 00:33
138 0.000154 0.019418 0.995730 00:33
139 0.000699 0.019971 0.995730 00:32
140 0.000355 0.024038 0.994876 00:32
141 0.000170 0.030813 0.994876 00:33
142 0.000657 0.027899 0.994876 00:32
143 0.001425 0.024708 0.995730 00:33
144 0.000381 0.020135 0.994022 00:32
145 0.000152 0.025634 0.994876 00:33
146 0.000075 0.018921 0.994876 00:33
147 0.000226 0.017673 0.994876 00:33
148 0.000224 0.023066 0.996584 00:32
149 0.000632 0.018082 0.994876 00:33
150 0.000625 0.016179 0.996584 00:32
151 0.000080 0.021201 0.994876 00:32
152 0.000068 0.021460 0.994022 00:32
153 0.000112 0.018794 0.995730 00:33
154 0.000080 0.021812 0.994876 00:32
155 0.000040 0.018293 0.995730 00:32
156 0.000171 0.018570 0.997438 00:33
157 0.000175 0.015313 0.996584 00:32
158 0.000464 0.016535 0.996584 00:33
159 0.000109 0.019572 0.996584 00:33
160 0.000062 0.021594 0.996584 00:33
161 0.000064 0.014384 0.996584 00:32
162 0.000014 0.020526 0.996584 00:33
163 0.000028 0.019420 0.995730 00:32
164 0.000042 0.030555 0.994876 00:33
165 0.000080 0.022019 0.996584 00:33
166 0.000079 0.030117 0.994876 00:33
167 0.000038 0.019891 0.996584 00:33
168 0.000027 0.024130 0.996584 00:33
169 0.000017 0.027270 0.995730 00:32
170 0.000032 0.018282 0.995730 00:33
171 0.000062 0.019155 0.996584 00:33
172 0.000059 0.023948 0.995730 00:32
173 0.000011 0.025428 0.995730 00:33
174 0.000011 0.019787 0.995730 00:33
175 0.000018 0.025644 0.995730 00:33
176 0.000185 0.021899 0.995730 00:33
177 0.000056 0.021866 0.995730 00:33
178 0.000061 0.022560 0.995730 00:33
179 0.000019 0.019159 0.995730 00:33
180 0.000009 0.024180 0.995730 00:33
181 0.000030 0.022470 0.995730 00:33
182 0.000007 0.020468 0.995730 00:32
183 0.000049 0.024680 0.995730 00:33
184 0.000009 0.019799 0.994876 00:32
185 0.000026 0.025008 0.995730 00:33
186 0.000028 0.029448 0.995730 00:33
187 0.000161 0.032871 0.995730 00:32
188 0.000334 0.028276 0.995730 00:33
189 0.000033 0.023425 0.995730 00:32
190 0.000012 0.027646 0.995730 00:33
191 0.000012 0.026857 0.995730 00:33
192 0.000120 0.025125 0.995730 00:33
193 0.000014 0.029498 0.995730 00:33
194 0.000010 0.028255 0.995730 00:33
195 0.000098 0.027213 0.995730 00:33
196 0.000031 0.024639 0.995730 00:33
197 0.000021 0.028268 0.995730 00:32
198 0.000005 0.021215 0.995730 00:32
199 0.000010 0.027356 0.995730 00:32

CAM 결과 확인_에폭 200

# 5x5 grid of CAM overlays for dataset images 0..24. Each tile shows the image
# with the CAM of the predicted class (channel 0 = NORMAL, 1 = PNEUMONIA, the
# dls.vocab order) and that class's softmax probability.
img_files = get_image_files(path)  # scan the directory tree once, not once per tile
device = next(net.parameters()).device  # feed inputs to wherever the model lives
fig, ax = plt.subplots(5, 5)
k = 0
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(img_files[k])]))
        x = x.to(device)
        with torch.no_grad():  # inference only: no autograd graph needed
            camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
            logits = net(x)[0]
            a, b = logits.tolist()
            # softmax stays finite where np.exp(a)/(np.exp(a)+np.exp(b)) overflows
            normalprob, pneumoniaprob = torch.softmax(logits, dim=0).tolist()
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        if normalprob > pneumoniaprob:
            cam_idx, title = 0, "normal(%s)" % round(normalprob, 5)
        else:
            cam_idx, title = 1, "pneumonia(%s)" % round(pneumoniaprob, 5)
        ax[i][j].imshow(camimg[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                        interpolation='bilinear', cmap='magma')
        ax[i][j].set_title(title)
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
# 5x5 grid of CAM overlays for dataset images 3000..3024. Each tile shows the
# image with the CAM of the predicted class (channel 0 = NORMAL, 1 = PNEUMONIA,
# the dls.vocab order) and that class's softmax probability.
img_files = get_image_files(path)  # scan the directory tree once, not once per tile
device = next(net.parameters()).device  # feed inputs to wherever the model lives
fig, ax = plt.subplots(5, 5)
k = 3000
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(img_files[k])]))
        x = x.to(device)
        with torch.no_grad():  # inference only: no autograd graph needed
            camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
            logits = net(x)[0]
            a, b = logits.tolist()
            # softmax stays finite where np.exp(a)/(np.exp(a)+np.exp(b)) overflows
            normalprob, pneumoniaprob = torch.softmax(logits, dim=0).tolist()
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        if normalprob > pneumoniaprob:
            cam_idx, title = 0, "normal(%s)" % round(normalprob, 5)
        else:
            cam_idx, title = 1, "pneumonia(%s)" % round(pneumoniaprob, 5)
        ax[i][j].imshow(camimg[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                        interpolation='bilinear', cmap='magma')
        ax[i][j].set_title(title)
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()

SAMPLE

# Path of the first image found under the dataset root (the running sample below).
sample_path = get_image_files(path)[0]
sample_path
Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0766-0001.jpeg')
# Load the sample image and push it through the dataloaders' test pipeline to
# obtain a single-image batch tensor `x`.
sample_file = get_image_files(path)[0]
img = PILImage.create(sample_file)
img
sample_dl = dls.test_dl([img])
x, = first(sample_dl)  # tensorise the image
x.shape
torch.Size([1, 3, 224, 224])

판단 근거가 강할수록 색이 하늘색(cyan) $\to$ 분홍색(magenta)으로 변함 (cmap='cool')

# Softmax class probabilities (NORMAL, PNEUMONIA) for the sample image x.
# Uses a single forward pass. torch.softmax is computed stably, whereas
# np.exp(a)/(np.exp(a)+np.exp(b)) yields nan once a logit exceeds ~709.
# x is moved to the model's device so the cell works whether net is on CPU or GPU.
with torch.no_grad():
    logits = net(x.to(next(net.parameters()).device))[0].double().cpu()
a, b = logits.tolist()
tuple(torch.softmax(logits, dim=0).tolist())
(0.9999999999930361, 6.9638572672406385e-12)
# 1st CAM: weight the body's final feature maps by each class's linear-layer
# weights (channel 0 = NORMAL, 1 = PNEUMONIA) and overlay both maps on the image.
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x.to("cpu")).squeeze())
fig, (ax1, ax2) = plt.subplots(1, 2)
panels = ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART"))
for axis, cam_idx, title in panels:
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(camimg[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# MODE1 weights from the 1st CAM (NORMAL channel) with theta = 0.04:
#   A1 = exp(-theta * (CAM - min CAM)) -> RES weight, small where the CAM is strong
#   A2 = 1 - A1                        -> MODE weight, large where the CAM is strong
test = camimg[0] - torch.min(camimg[0])
A1 = torch.exp(-0.04 * test)
A2 = 1 - A1
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, weight, title in ((ax1, A2, "MODE1 WEIGHT"), (ax2, A1, "MODE1 RES WEIGHT")):
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(weight.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
  • $\theta$가 작아질수록 MODE1 WEIGHT($1-e^{-\theta(\mathrm{CAM}-\min\mathrm{CAM})}$)가 크게 나타나는 범위가 좁아지는 경향이 있다.
# Upsample the CAM-resolution weight maps to 224x224 and apply them to the image:
#   x1  = image * resized A1 (MODE1 RES)
#   x12 = image * resized A2 (MODE1)
X1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv2.resize(X1, (224, 224), interpolation=cv2.INTER_LINEAR))
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = torch.Tensor(cv2.resize(X12, (224, 224), interpolation=cv2.INTER_LINEAR))
x_cpu = x.squeeze().to('cpu')
x1 = x_cpu * Y1
x12 = x_cpu * Y12
  • 1st CAM 결과를 분리하면 아래와 같음.
# Show the original sample, then MODE1 (x12) next to MODE1 RES (x1), both
# scaled by 0.5 for display.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)
fig.set_figheight(4)
fig.tight_layout()

fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, part, title in ((ax1, x12, "MODE1"), (ax2, x1, "MODE1 RES")):
    (part * 0.5).squeeze().show(ax=axis)
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Restore the batch dimension on x1 and put both model halves on the CPU,
# where x1 was built.
x1 = x1.reshape(1, 3, 224, 224)
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 2nd CAM: project the body's feature maps for x1 onto each class's linear weights.
features1 = net1(x1).squeeze()
camimg1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, features1)
  • CAM

    • mode1_res에 CAM 결과 올리기
# 2nd CAM (camimg1) overlaid on the MODE1 RES image x1, one panel per class.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam_idx, title in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    x1.squeeze().show(ax=axis)
    axis.imshow(camimg1[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class CAM before (camimg, from x) and after (camimg1, from x1) removing
# the 1st-CAM evidence, both overlaid on the original image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam, title in ((ax1, camimg, "1ST CAM"), (ax2, camimg1, "2ND CAM")):
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(cam[0].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Softmax class probabilities (NORMAL, PNEUMONIA) for the MODE1 RES image x1.
# Uses one forward pass instead of two. torch.softmax cannot overflow the way
# np.exp(a1)/(np.exp(a1)+np.exp(b1)) does for large logits.
with torch.no_grad():
    logits1 = net(x1)[0].double()
a1, b1 = logits1.tolist()
tuple(torch.softmax(logits1, dim=0).tolist())
(0.9905858008146032, 0.00941419918539672)
# 2nd CAM stored as ver2: the body's feature maps for x1 weighted by each class's
# linear-layer weights.
features_ver2 = net1(x1).squeeze()
ver2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, features_ver2)
  • CAM
# 2nd CAM (ver2) overlaid on the MODE1 RES image x1, one panel per class.
# The background must be x1, the image ver2 was computed from; the name `xx1`
# is not defined anywhere and raised NameError.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam_idx, title in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    x1.squeeze().show(ax=axis)
    axis.imshow(ver2[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class CAM from the original image (camimg) next to the 2nd CAM (ver2),
# both overlaid on the original image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam, title in ((ax1, camimg, "1ST CAM"), (ax2, ver2, "2ND CAM")):
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(cam[0].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Softmax class probabilities (NORMAL, PNEUMONIA) for the MODE1 RES image x1.
# Uses one forward pass instead of two. torch.softmax cannot overflow the way
# np.exp(a1)/(np.exp(a1)+np.exp(b1)) does for large logits.
with torch.no_grad():
    logits1 = net(x1)[0].double()
a1, b1 = logits1.tolist()
tuple(torch.softmax(logits1, dim=0).tolist())
(0.9905858008146032, 0.00941419918539672)
  • 이 정도면 괜찮은 것 같다.
  • 첫번째 CAM에서 정상 판단 근거였던 폐의 가운데 부분이 어두워지자 약간 오른쪽 폐로 이동한 모습.
  • CAT/DOG 예제에서 $\theta$를 2배씩 늘려나갔으나, 여기서는 $\theta$를 이전과 동일하게 유지함.
# MODE2 weights from the 2nd CAM (ver2, NORMAL channel) with theta = 0.04, using
# the same convention as the MODE1 weights:
#   A3 = exp(-theta * (CAM - min CAM)) -> RES weight, ~0 where the CAM is strong
#   A4 = 1 - A3                        -> MODE weight, ~1 where the CAM is strong
# A3 is the weight later applied to build x3, which is displayed as "MODE2 RES".
# So A4 goes under the "MODE2 WEIGHT" title and A3 under "MODE2 RES WEIGHT".
test1 = ver2[0] - torch.min(ver2[0])
A3 = torch.exp(-0.04 * test1)
A4 = 1 - A3
fig, (ax1, ax2) = plt.subplots(1, 2)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A4.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
           interpolation='bilinear', cmap='cool')
ax1.set_title("MODE2 WEIGHT WITH THETA=0.04")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A3.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
           interpolation='bilinear', cmap='cool')
ax2.set_title("MODE2 RES WEIGHT WITH THETA=0.04")
#
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Upsample the MODE2 weights and apply them on top of the MODE1 RES weight Y1:
#   x3 = image * Y1 * resized A3 (MODE2 RES)
#   x4 = image * Y1 * resized A4 (MODE2)
X3 = np.array(A3.to("cpu").detach(), dtype=np.float32)
Y3 = torch.Tensor(cv2.resize(X3, (224, 224), interpolation=cv2.INTER_LINEAR))
X4 = np.array(A4.to("cpu").detach(), dtype=np.float32)
Y4 = torch.Tensor(cv2.resize(X4, (224, 224), interpolation=cv2.INTER_LINEAR))
x_res1 = x.squeeze().to('cpu') * Y1
x3 = x_res1 * Y3
x4 = x_res1 * Y4
  • 2nd CAM 결과를 분리하면 아래와 같음.
# Show the original sample, then the MODE1 / MODE1 RES split and the
# MODE2 / MODE2 RES split side by side.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)
fig.set_figheight(4)
fig.tight_layout()

pairs = (
    (x12, "MODE1", x1, "MODE1 RES"),
    (x4, "MODE2", x3, "MODE2 RES"),
)
for left_img, left_title, right_img, right_title in pairs:
    fig, (ax1, ax2) = plt.subplots(1, 2)
    left_img.squeeze().show(ax=ax1)
    right_img.squeeze().show(ax=ax2)
    ax1.set_title(left_title)
    ax2.set_title(right_title)
    fig.set_figwidth(8)
    fig.set_figheight(8)
    fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Restore the batch dimension on x3 and put both model halves on the CPU,
# where x3 was built.
x3 = x3.reshape(1, 3, 224, 224)
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 3rd CAM: the body's feature maps for the MODE2 RES image x3 weighted by each
# class's linear-layer weights.
features3 = net1(x3).squeeze()
ver3 = torch.einsum('ij,jkl -> ikl', net2[2].weight, features3)
  • CAM
# 3rd CAM (ver3) overlaid on the MODE2 RES image x3, one panel per class.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam_idx, title in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    x3.squeeze().show(ax=axis)
    axis.imshow(ver3[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class CAM from each round, all overlaid on the original image:
# camimg (1st, from x), ver2 (2nd, from x1) and ver3 (3rd, from x3).
# The 3rd panel must use ver3; the name `ver22` is never defined (NameError).
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
panels = ((ax1, camimg, "1ST CAM"), (ax2, ver2, "2ND CAM"), (ax3, ver3, "3RD CAM"))
for axis, cam, title in panels:
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(cam[0].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)

fig.set_figwidth(12)
fig.set_figheight(12)
fig.tight_layout()
# Softmax class probabilities (NORMAL, PNEUMONIA) for the MODE2 RES image x3.
# Uses one forward pass instead of two. torch.softmax cannot overflow the way
# np.exp(a2)/(np.exp(a2)+np.exp(b2)) does for large logits.
with torch.no_grad():
    logits2 = net(x3)[0].double()
a2, b2 = logits2.tolist()
tuple(torch.softmax(logits2, dim=0).tolist())
(8.003716145531185e-16, 0.9999999999999991)

$\theta=0.055$

✏️ CAM에서 발견한 특징 부분의 가중치를 살리는 mode weight($1-e^{-\theta(\mathrm{CAM}-\min\mathrm{CAM})}$)와, CAM에서 발견한 특징의 가중치를 낮춰 나머지 부분을 살리는 mode res weight($e^{-\theta(\mathrm{CAM}-\min\mathrm{CAM})}$)를 생성해 mode$n$($n$=number of acting CAM)과 mode$n$ res를 생성한다.
이때 CAM 값에 하이퍼파라미터 $\theta$를 곱한 뒤 지수함수를 취해 그 정도를 조정하는데,

  • $\theta$의 값이 클수록, CAM img에 큰 값이 곱해져 이미지가 어떤 클래스로 분류되는 데 기여한 부분의 범위가 넓어지며(=mode weight에서 분홍색으로 표시되는 부위가 실제 CAM img보다 넓어짐), residual에서는 더욱 어둡게 나타난다.
  • $\theta$의 값이 작을수록, CAM img에 작은 값이 곱해져 이미지가 어떤 클래스로 분류되는 데 기여한 부분의 범위가 CAM img에서 나타난 것과 유사하나, residual에서 완벽하게 지워지지 않아서 다음 차수에서 동일한 부분이 탐색될 가능성이 높다.

위와 같은 이유로 적당한 $\theta$를 시뮬레이션을 통해 임의로 설정하였다.

  • $\theta=0.04$로 진행해본 결과, 앞서 언급한 것처럼 1차에서 발견된 특징이 잘 지워지지 않았다. 따라서 2번째 CAM에서 발견된 특징이 1번째 CAM에서 발견된 특징과 유사하다.
  • $\theta=0.055$로 진행해본 결과, 1차에서 발견된 특징이 residual img에서 잘 지워졌으나(픽셀 값이 0에 가까워져 검정색에 가깝게 나타남), 폐 내부로 보이는 영역의 시각적인 특징들도 함께 사라져 2번째 CAM에서 엉뚱한 결과를 초래했다.
# MODE1 weights from the 1st CAM (NORMAL channel) with theta = 0.055:
#   A1 = exp(-theta * (CAM - min CAM)) -> RES weight, small where the CAM is strong
#   A2 = 1 - A1                        -> MODE weight, large where the CAM is strong
test = camimg[0] - torch.min(camimg[0])
A1 = torch.exp(-0.055 * test)
A2 = 1 - A1
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, weight, title in ((ax1, A2, "MODE1 WEIGHT"), (ax2, A1, "MODE1 RES WEIGHT")):
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(weight.data.to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Upsample the CAM-resolution weight maps to 224x224 and apply them to the image:
#   x1  = image * resized A1 (MODE1 RES)
#   x12 = image * resized A2 (MODE1)
X1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv2.resize(X1, (224, 224), interpolation=cv2.INTER_LINEAR))
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = torch.Tensor(cv2.resize(X12, (224, 224), interpolation=cv2.INTER_LINEAR))
x_cpu = x.squeeze().to('cpu')
x1 = x_cpu * Y1
x12 = x_cpu * Y12
  • 1st CAM 결과를 분리하면 아래와 같음.
# Show the original sample, then MODE1 (x12) next to MODE1 RES (x1), both
# scaled by 0.5 for display.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)
fig.set_figheight(4)
fig.tight_layout()

fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, part, title in ((ax1, x12, "MODE1"), (ax2, x1, "MODE1 RES")):
    (part * 0.5).squeeze().show(ax=axis)
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Restore the batch dimension on x1 and put both model halves on the CPU,
# where x1 was built.
x1 = x1.reshape(1, 3, 224, 224)
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 2nd CAM: project the body's feature maps for x1 onto each class's linear weights.
features1 = net1(x1).squeeze()
camimg1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, features1)
  • CAM

    • mode1_res에 CAM 결과 올리기
# 2nd CAM (camimg1) overlaid on the MODE1 RES image x1, one panel per class.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam_idx, title in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    x1.squeeze().show(ax=axis)
    axis.imshow(camimg1[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# NORMAL-class CAM before (camimg, from x) and after (camimg1, from x1) removing
# the 1st-CAM evidence, both overlaid on the original image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for axis, cam, title in ((ax1, camimg, "1ST CAM"), (ax2, camimg1, "2ND CAM")):
    dls.train.decode((x,))[0].squeeze().show(ax=axis)
    axis.imshow(cam[0].to("cpu").detach(), alpha=0.5, extent=(0, 223, 223, 0),
                interpolation='bilinear', cmap='cool')
    axis.set_title(title)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Softmax class probabilities (NORMAL, PNEUMONIA) for the MODE1 RES image x1.
# Uses one forward pass instead of two. torch.softmax cannot overflow the way
# np.exp(a1)/(np.exp(a1)+np.exp(b1)) does for large logits.
with torch.no_grad():
    logits1 = net(x1)[0].double()
a1, b1 = logits1.tolist()
tuple(torch.softmax(logits1, dim=0).tolist())
(2.165395603263705e-07, 0.9999997834604397)

  • 전체 그림에 적용하기 (n=11712)
# 5x5 grid of CAM overlays for dataset images 3000..3024. Each tile shows the
# image with the CAM of the predicted class (channel 0 = NORMAL, 1 = PNEUMONIA,
# the dls.vocab order) and that class's softmax probability.
img_files = get_image_files(path)  # scan the directory tree once, not once per tile
# The earlier cells moved net to the CPU, while test_dl batches land on
# dls.device, so align the input with the model.
device = next(net.parameters()).device
fig, ax = plt.subplots(5, 5)
k = 3000
for i in range(5):
    for j in range(5):
        x, = first(dls.test_dl([PILImage.create(img_files[k])]))
        x = x.to(device)
        with torch.no_grad():  # inference only: no autograd graph needed
            camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
            logits = net(x)[0]
            a, b = logits.tolist()
            # softmax stays finite where np.exp(a)/(np.exp(a)+np.exp(b)) overflows
            normalprob, pneumoniaprob = torch.softmax(logits, dim=0).tolist()
        dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
        if normalprob > pneumoniaprob:
            cam_idx, title = 0, "normal(%s)" % round(normalprob, 5)
        else:
            cam_idx, title = 1, "pneumonia(%s)" % round(pneumoniaprob, 5)
        ax[i][j].imshow(camimg[cam_idx].to("cpu").detach(), alpha=0.5, extent=(0, 224, 224, 0),
                        interpolation='bilinear', cmap='magma')
        ax[i][j].set_title(title)
        k = k + 1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
import pandas as pd

# Collect the indices 0..4, printing each one.
# Append rather than index-assign: `col` starts empty, so `col[k] = ...` raises
# IndexError. The for statement already advances k, so no manual increment.
col = []
for k in range(5):
    print(k)
    col.append(k)
0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_170705/2694417875.py in <module>
      2 col=[]
      3 for k in range(5) :
----> 4     col[k]=print(k)
      5     k=k+1

IndexError: list assignment index out of range

# Confusion matrix of lrnr2's predictions on the validation set (ds_idx=1, the default).
interp = ClassificationInterpretation.from_learner(lrnr2, ds_idx=1)
interp.plot_confusion_matrix()
# cleaner: remove mispredicted images (it seems to display the images that would be removed)